标记数据 (Flagging)
数据标记是 Gradio 中一个重要的功能,允许用户在使用您的机器学习模型时,对特定的输入-输出对进行标记。这种功能在收集反馈、识别模型错误或构建标注数据集时非常有用。
标记按钮介绍
默认情况下,每个 Gradio Interface
的输出组件下方都会显示一个"标记"(Flag)按钮。当用户看到有趣的、意外的或错误的结果时,他们可以点击这个按钮,将当前的输入和输出数据发送回运行演示的服务器。这些标记的数据可以被保存用于后续的模型改进。
配置标记行为
在 gr.Interface
构造函数中,有四个关键参数用于控制标记功能的行为:
1. flagging_mode
该参数控制标记按钮的显示方式和标记行为:
"manual"
(默认): 用户将看到一个标记按钮,只有在点击按钮时才会标记样本。"auto"
: 用户不会看到标记按钮,但每个提交的样本都会自动标记。"never"
: 禁用标记功能,用户不会看到标记按钮,也不会标记任何样本。
python
import gradio as gr
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
elif operation == "subtract":
return num1 - num2
elif operation == "multiply":
return num1 * num2
elif operation == "divide":
if num2 == 0:
raise gr.Error("不能除以零!")
return num1 / num2
# 自动标记所有提交的计算
demo = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_mode="auto"
)
demo.launch()
2. flagging_options
该参数允许您自定义标记的原因选项:
- 如果为
None
(默认),用户只需点击"标记"按钮,不显示其他选项。 - 如果提供字符串列表,用户将看到多个标记按钮,每个按钮对应提供的字符串。例如,
["不正确", "有歧义"]
将显示"标记为不正确"和"标记为有歧义"按钮。 - 用户选择的选项将与输入和输出一起记录在标记数据中。
python
# 提供标记原因选项
demo = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_mode="manual",
flagging_options=["计算错误", "除以零", "其他问题"]
)
也可以使用元组列表提供自定义的标签和值:
python
# 使用自定义标签和值
demo = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_mode="manual",
flagging_options=[
("结果不正确", "incorrect_result"),
("操作不支持", "unsupported_operation"),
("其他问题", "other_issue")
]
)
3. flagging_dir
该参数指定存储标记数据的目录:
python
demo = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_dir="./my_flagged_data" # 自定义标记数据存储目录
)
如果不指定,默认为 ./.gradio/flagged/
。
4. flagging_callback
此参数允许您使用自定义的回调函数来处理标记的数据,而不是默认的 CSV 记录方式:
python
import gradio as gr
from gradio.flagging import FlaggingCallback
class MyCustomFlaggingCallback(FlaggingCallback):
def setup(self, components, flagging_dir):
# 初始化设置,例如连接到数据库
self.log_file = open(f"{flagging_dir}/custom_logs.txt", "a")
return self
def flag(self, flag_data, flag_option=None):
# 处理标记数据
data_str = ", ".join([str(d) for d in flag_data])
if flag_option:
self.log_file.write(f"标记原因: {flag_option}, 数据: {data_str}\n")
else:
self.log_file.write(f"数据: {data_str}\n")
self.log_file.flush()
return
demo = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_callback=MyCustomFlaggingCallback()
)
标记数据的存储格式
当用户点击标记按钮时,数据将按照以下方式存储:
基本数据存储
对于基本的原始数据(数字、文本等),数据将存储在一个 CSV 文件中:
# <flagging_dir>/logs.csv
num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542
文件数据存储
如果您的接口包含文件类型的输入或输出(如图像、音频等),这些文件将单独保存,CSV 文件中只存储文件路径:
# 目录结构
+-- flagged/
| +-- logs.csv
| +-- image/
| | +-- 0.png
| | +-- 1.png
| +-- Output/
| | +-- 0.png
| | +-- 1.png
# <flagging_dir>/logs.csv
image,Output,timestamp
image/0.png,Output/0.png,2022-02-04 19:49:58.026963
image/1.png,Output/1.png,2022-02-02 10:40:51.093412
带有标记选项的数据
如果您使用了 flagging_options
,被选择的选项也会记录在 CSV 文件中:
# <flagging_dir>/logs.csv
num1,operation,num2,Output,flag,timestamp
5,add,7,-12,计算错误,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,其他问题,2022-02-04 11:42:32.062512
在 Blocks 中使用标记功能
在 gr.Blocks()
中,您也可以实现标记功能,但这需要手动设置:
python
import gradio as gr
import numpy as np
def sepia(input_img):
sepia_filter = np.array([
[0.393, 0.769, 0.189],
[0.349, 0.686, 0.168],
[0.272, 0.534, 0.131]
])
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
# 创建 CSV 记录器
callback = gr.CSVLogger()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
img_input = gr.Image()
transform_btn = gr.Button("应用滤镜")
img_output = gr.Image()
flag_btn = gr.Button("标记此结果")
# 设置 logger
callback.setup([img_input, img_output], "flagged_data")
# 设置事件
transform_btn.click(sepia, inputs=img_input, outputs=img_output)
# 标记按钮点击事件
flag_btn.click(
lambda img_in, img_out: callback.flag([img_in, img_out]),
[img_input, img_output],
None,
preprocess=False
)
demo.launch()
标记数据的应用
通过标记功能收集的数据可以用于多种目的:
- 识别模型错误:标记功能可以帮助您收集模型表现不佳的数据点。
- 创建测试集:将收集到的难以处理的样本组织成测试集,用于模型评估。
- 改进模型:使用标记的数据进行模型的再训练或微调。
- 数据审计:检查模型在不同输入上的表现,识别潜在的偏见。
隐私考虑
使用标记功能时,请确保:
- 告知用户他们的数据何时会被保存。
- 明确您将如何使用这些标记的数据。
- 在使用
flagging_mode="auto"
时尤其要注意,因为所有用户提交的数据都会被自动保存。
结论
Gradio 的标记功能是一个强大的工具,可以帮助您收集反馈并改进模型。通过合理配置标记选项,您可以收集到有针对性的反馈,更有效地识别和解决模型中的问题。
在下一章中,我们将介绍如何在 Interface 中管理状态,这对于创建具有记忆功能的应用程序至关重要。